# This code is essentially the same as the script that calculates the economic zone.
# Compare it with 'economic_zone_calc' to see which elements are treated differently:
#   - function determine_economic_zone: no noise is considered
#   - bounds on T: the entire physical constraint
#   - risk factor: set very high so that it does not limit the upper bound of the cost


# imports
using LightGraphs, IntervalArithmetic, SpatialIndexing, Plots, PyCall
using InvariantSetApproximation  # Obtain CIS directly using the package
import LazySets, Polyhedra

""" Using Economic Zone Approach """
# create tree 2D
#
# Build an R-tree over the 2-D grid cells listed in `xd`. Each entry `(i, j)` of
# `xd` stands for the box x1d[i] × x2d[j]; it is inserted with value equal to its
# position in `xd`, so spatial queries on the tree return indices into `xd`.
function create_tree(xd::Vector{NTuple{2, Int64}}, x1d::Vector{Interval{Float64}}, x2d::Vector{Interval{Float64}})
    tree = RTree{Float64,2}(Int)
    for (id, (i, j)) in enumerate(xd)
        # Bounds-checked on purpose: (i, j) come from the caller, not from
        # eachindex, so an invalid pair must raise a BoundsError.
        lo = (x1d[i].lo, x2d[j].lo)
        hi = (x1d[i].hi, x2d[j].hi)
        insert!(tree, SpatialIndexing.Rect(lo, hi), id)
    end
    return tree
end

"""
    ancestors(G, src)

Return, in increasing vertex order, every vertex of `G` from which some vertex
in `src` can be reached. The vertices of `src` themselves are included.

`G` is reversed in place for the search and then restored. The restore sits in
a `finally` clause, so the caller's graph is not left reversed if the search
throws.
"""
function ancestors(G::SimpleDiGraph{Int64}, src)
    reverse!(G)   # ancestors in G are the vertices reachable from src in reverse(G)
    try
        dist = gdistances(G, src)
        # unreachable vertices are reported with distance typemax(Int64)
        return findall(d -> d < typemax(Int64), dist)
    finally
        reverse!(G)
    end
end

"""
    select_nonleaving_cells(G)

Return the vertices of the largest strongly connected component of `G`,
followed by every other vertex that has a path into that component.

When several components share the largest size, the first one listed by
`strongly_connected_components` is used. If `G` has no vertices, an empty
vector is returned instead of failing on an out-of-range index.
"""
function select_nonleaving_cells(G::SimpleDiGraph{Int64})
    scc = strongly_connected_components(G)
    isempty(scc) && return Int64[]
    # argmax returns the first maximal index, matching "first largest" selection
    largest = scc[argmax(map(length, scc))]
    return union(largest, ancestors(G, largest))
end

# construct 2D boxes for plotting
#
# Turn each grid cell (i, j) of `xd` into the axis-aligned box
# xd1[i] × xd2[j] as a LazySets.Hyperrectangle. The result has the same
# order as `xd`.
function construct_face(xd::Vector{NTuple{2, Int64}},xd1::Vector{Interval{Float64}},xd2::Vector{Interval{Float64}})
    BoxT = typeof(LazySets.Hyperrectangle(low=[0.0,0.0],high=[1.,1.]))
    rects = Vector{BoxT}(undef, length(xd))
    for (k, (i, j)) in enumerate(xd)
        lower = [xd1[i].lo, xd2[j].lo]
        upper = [xd1[i].hi, xd2[j].hi]
        rects[k] = LazySets.Hyperrectangle(low=lower, high=upper)
    end
    return rects
end

"""
    create_graph(xd, x1d, x2d, ud, w, model)

Build the cell transition graph of the gridded system.

Vertex `id` is the cell `x1d[i] × x2d[j]` with `(i, j) = xd[id]`. An edge
`id -> k` is added whenever the interval image `model(x, ud[1], w)` of cell
`id` intersects cell `k`, as reported by `SpatialIndexing.intersects_with`.
Self-loops are not added.

Only the first input `ud[1]` is used.
"""
function create_graph(xd::Vector{NTuple{2, Int64}}, x1d::Vector{Interval{Float64}}, x2d::Vector{Interval{Float64}}, ud::Vector{Tuple{Interval{Float64}}}, w, model)
    tree = create_tree(xd, x1d, x2d)
    G = SimpleDiGraph(length(xd))
    u = ud[1]
    for (id, (i, j)) in enumerate(xd)
        # Bounds-checked: (i, j) come from the caller, not from eachindex.
        y = model((x1d[i], x2d[j]), u, w)
        image = SpatialIndexing.Rect((y[1].lo, y[2].lo), (y[1].hi, y[2].hi))
        for hit in SpatialIndexing.intersects_with(tree, image)
            hit.val != id && add_edge!(G, id, hit.val)
        end
    end
    return G
end

"""
    economic_cost(x, u)

Economic stage cost (Equation (40)).

The cost is the first state `x[1]` plus a quadratic penalty with weight 10 on
how far the temperature interval `x[2]` (read through its `lo`/`hi` fields)
sticks out of the band [348, 352].

Only one side is penalised. If `x[2].hi > 352`, the upper violation is used,
even when `x[2].lo < 348` as well. Otherwise a lower violation is penalised.
The input `u` is not used.
"""
function economic_cost(x, u)
    base = 1 * x[1]
    upper = x[2].hi
    upper > 352.0 && return base + 10 * (upper - 352.0)^2
    lower = x[2].lo
    lower < 348.0 && return base + 10 * (348.0 - lower)^2
    return base
end

"""
    determine_economic_zone(xd, x1d, x2d, ud, w, delta, M)

Return the cells of `xd` whose economic cost is guaranteed to stay within
`delta`. A cell `(i, j)` is kept when the upper bound of
`economic_cost([x1d[i], x2d[j]], ud[1])` is `<= delta`. Cells keep their
original order.

No noise is considered: the disturbance `w` is not used. The argument `M` is
currently unused as well.
"""
function determine_economic_zone(xd::Vector{NTuple{2, Int64}}, x1d::Vector{Interval{Float64}}, x2d::Vector{Interval{Float64}}, ud::Vector{Tuple{Interval{Float64}}}, w::Vector{Interval{Float64}}, delta::Float64, M::Float64)
    u = ud[1]
    within_budget(cell) = economic_cost([x1d[cell[1]], x2d[cell[2]]], u).hi <= delta
    return filter(within_budget, xd)   # A1L2-L5: Ce ⊆ xd
end


"""
    computeEconomicZone(d, delta)

Compute three sets in the (C_A, T) plane for the discretised CSTR model:

- `Xt`: convex hull of the whole gridded search region
  (C_A ∈ [0, 1], T ∈ [348.5, 351.5]);
- `Xe`: convex hull of the cells whose economic cost upper bound is `<= delta`;
- `Xr`: convex hull of the cells of the economic zone that `select_nonleaving_cells`
  keeps for every disturbance vertex in `W`. This approximates the robust control
  invariant set inside the economic zone.

`d` is the number of subintervals each state range is split into, so the grid
has `d^2` cells. `delta` bounds the economic cost.

Returns `(Xr, Xe, Xt)` as `LazySets.ConvexHullArray`s. Throws an `ArgumentError`
when no cell satisfies the cost bound.
"""
function computeEconomicZone(d, delta)
    # State, input and disturbance boxes
    x1 = @interval(0.0,1.)        # concentration: no zone, use the entire state constraint
    x2 = @interval(348.5,351.5)   # temperature search range
    u1 = @interval(285.0,315.0)
    w1 = @interval(-0.1,0.1)      # nominal value 1
    w2 = @interval(-2.0,2.0)      # nominal value 350

    x1s = mince(x1,d)
    x2s = mince(x2,d)
    u1s = mince(u1,1)             # one input interval covering the whole input range

    # CSTR model parameters
    q = 100.0       # inlet flow rate, L/min
    Tin = 350.0     # inlet temperature, K
    Cin = 1.0       # inlet concentration, kmol/m^3
    V = 100.0       # Volume of reactor, L
    k0 = 7.2E10     # rate constant, 1/min
    EoR = 8750.0    # Activate energy/gas constant, K
    UA = 5.0E4      # heat transfer coefficient, J/min.K
    rho = 1000.0    # density, g/L
    Cp = 0.239      # heat capacity, J/g.K
    DH = -5.0E4     # heat of reaction, J/mol
    T = 0.1         # discretization step size, min

    # One explicit-Euler step of the CSTR dynamics, evaluated in interval arithmetic.
    # The expressions keep the original term grouping so interval results are unchanged.
    f1de(x::Tuple{Interval{Float64},Interval{Float64}},u::Tuple{Interval{Float64}},w) = (
        x[1] + T*(( (q/V)*(Cin + w[1] - x[1]) - k0*exp(-EoR/x[2])*x[1] ) ),
        x[2] + T*( ( (q/V)*(Tin + w[2] - x[2]) + (-DH/(rho*Cp))*k0*exp(-EoR/x[2])*x[1] + (UA/(V*rho*Cp))*(-x[2]) ) +  (UA/(V*rho*Cp))*u[1] ),
    )

    # Grid cells (i, j) <-> x1s[i] × x2s[j] covering the whole search region
    x12d = [(i,j) for i in eachindex(x1s) for j in eachindex(x2s)]
    Xt = LazySets.ConvexHullArray(construct_face(x12d,x1s,x2s))
    u1d = [(u,) for u in u1s]
    w = [w1,w2]

    # Step 1: Determine economic zone within the target zone (Ce ⊆ x12d)
    Ce = determine_economic_zone(x12d,x1s,x2s,u1d,w,delta,10.0)
    println(length(Ce))
    isempty(Ce) && throw(ArgumentError(
        "no cell satisfies the economic cost bound delta = $delta; increase delta"))
    Xe = LazySets.ConvexHullArray(construct_face(Ce,x1s,x2s))

    # Step 2: Determine cells that approximate the robust control invariant set  # A1L6-L12
    # Using the condition in Algorithm 2 is equivalent to enlarging the disturbance
    # set by e, so a graph is built for the nominal point and each enlarged vertex.
    e = (0.005,0.03)
    W = [(0.0,0.0),(-0.1-e[1],-2.0-e[2]),(0.1+e[1],-2.0-e[2]),(-0.1-e[1],2.0+e[2]),(0.1+e[1],2.0+e[2])]

    # Vertex indices (into Ce) kept for each disturbance vertex
    kept = [select_nonleaving_cells(create_graph(Ce, x1s, x2s, u1d, wv, f1de)) for wv in W]
    # Keep only the cells selected for every disturbance vertex
    Cer = reduce(intersect, kept)
    Cr = Ce[Cer]

    @show x1s[1]
    @show x2s[1]

    Xr = LazySets.ConvexHullArray(construct_face(Cr,x1s,x2s))

    return Xr,Xe,Xt
end

# Run the analysis with 256 subintervals per state dimension and cost bound
# delta = 1000.0. The bound is kept high so it does not restrict the economic zone.
Xr, Xe, Xt = computeEconomicZone(256, 1000.0)

# Constraint list (half-spaces with fields `a`, `b`) of the convex hull of set `X`,
# with redundant constraints removed.
function hrep_constraints(X)
    vrep = LazySets.VPolytope(LazySets.vertices_list(X))
    hrep = LazySets.convert(LazySets.HPolytope, vrep)
    return LazySets.constraints_list(LazySets.remove_redundant_constraints(hrep))
end

Vr = hrep_constraints(Xr)
Ve = hrep_constraints(Xe)
Vt = hrep_constraints(Xt)

# """ Calculate CIS using InvariantSetApproximation package """
# # CSTR model parameters
# nx = 2					# Number of states
# nu = 1					# Number of inputs
# nzx = 1					# Number of zone states
# nzu = 0					# Number of zone inputs
# nw = 2					# Number of disturbances
# q = 100.0				# inlet flow rate, L/min
# Tin = 350.0				# inlet temperature, K
# Cin = 1.0				# inlet concentration, kmol/m^3
# V = 100.0				# Volume of reactor, L
# # r = 0.219				# radius of reactor, m
# k0 = 7.2E10				# rate constant, 1/min
# EoR = 8750.0			# Activate energy/gas constant, K
# UA = 5.0E4				# heat transfer coefficient, J/min.K
# rho = 1000				# density, g/L
# Cp = 0.239				# heat capacity, J/g.K
# DH = -5.0E4				# heat of reaction, J/mol
# T = 0.1                 # discretization step size, min
# # h = 0.659				# level of mixture, m
# function model(x::Vector{Float64},u::Vector{Float64},w::Vector{Float64})
# 	C = x[1]; Temp = x[2]; Tc = u[1]
# 	return [C + T*((q/V)*(Cin + w[1] - C) - k0*exp(-EoR/Temp)*C),
# 			Temp + T*( (q/V)*(Tin + w[2] - Temp) + (-DH/(rho*Cp))*k0*exp(-EoR/Temp)*C + (UA/(V*rho*Cp))*(Tc - Temp) )]
# end

# function main(model::Function, iter::Int)
#     # state, input and uncertainty bounds
#     Xub = [1., 352.]
#     Xlb = [0., 348.]
#     Uub = [315.]
#     Ulb = [285.]
#     # diameter of cell at 14 iterations is roughly 0.11
#     # enlarge W by at least 0.11 (we use 0.15 here) to disturbance to satisfy theorem 4 at the 14th subdivision
#     Wub = [0.1, 2.] .+ [0.005, 0.03]
#     Wlb = [-0.1, -2.] .+ -[0.005, 0.03]

#     # system
#     S = system(model, Xlb, Xub, ulb=Ulb, uub=Uub, wlb=Wlb, wub=Wub)

#     # options
#     options = Dict(:max_step=>iter, :nXsamples=>5, :Xsampletype=>:face, :nUsamples=>5, :nWsamples=>2)
#     O = Options(options)

#     # computation
#     sol = computeISet(S,O)

#     return sol
# end

# # run main
# sol = main(model, 21);
# # find convexhull of set (for faster viewing of solution)
# iset_cvxh = ConvexHullArray(sol.iset)
# # find half space representation
# iset_hrep = LazySets.constraints_list(LazySets.remove_redundant_constraints(LazySets.convert(LazySets.HPolytope,LazySets.VPolytope(LazySets.vertices_list(iset_cvxh)))))
plot(Xr)
# Export the half-space representation a_k' x <= b_k of Xr as NumPy arrays:
# hrepA collects the normals a_k (one per constraint), hrepB the offsets b_k.
@show Vr
np = pyimport("numpy")
hrepABig = [c.a for c in Vr]
hrepBBig = [c.b for c in Vr]
hrepABignp = np.asarray(hrepABig)
hrepBBignp = np.asarray(hrepBBig)
# np.save does not create missing directories, so make sure the target exists
mkpath("results")
np.save("results/hrepAPM0_5.npy", hrepABignp)
np.save("results/hrepBPM0_5.npy", hrepBBignp)

# # Episode 479 -----------------------------------------------------------------
# plot(Xt)
# plot!(Xe, label="Hard Constraint", yticks = 345:2:355, xlabel="C_A", ylabel="T", title="episode 479")
# plot!(Xr, label="CIS")
# # # Reward3
# # plot!([0.68129269, 0.64407758, 0.60764932, 0.57311953, 0.5408325, 0.51122743], [350.19626329, 351.57201625, 352.73635348, 353.80287968, 354.73881575, 355.60032376], label="episode 9998", arrow=true, markershape=[:circle])
# # plot!([0.72872944, 0.69672767, 0.64605488, 0.59054793], [347.09921032, 352.1227084, 354.84785636, 358.79010748], label="episode 9994", arrow=true, markershape=[:diamond])
# # plot!([2.22541853e-01, 2.70086896e-01, 3.20775829e-01, 0.36561357], [3.54328554e+02, 3.47340796e+02, 3.45454915e+02, 344.35737571], label="episode 9977", arrow=true, markershape=[:star5])
# # # Reward4
# # plot!([2.17310136e-01, 2.66832392e-01], [3.53962277e+02, 3.46030761e+02], label="episode 9999", arrow=true, markershape=[:circle])

# # Compare reward functions
# plot!([0.75265496, 0.70942194],[348.57886074, 349.79654461], label="RL1, action=285.75", arrow=true, markershape=[:circle])
# plot!([0.75265496, 0.70942194],[348.57886074, 349.84493221], label="RL2, action=285.98", arrow=true, markershape=[:diamond])
# plot!([0.75265496, 0.70942194],[348.57886074, 349.9200578], label="RL3, action=286.34", arrow=true, markershape=[:star5])
# plot!([0.75265496, 0.70942194, 0.66893766, 0.63106087],[348.57886074, 349.72210853, 350.83362371, 351.89615595], label="RL3_run2, action=285.40, 285.36, 285.33", arrow=true, markershape=[:utri])

# # Episode 61 ------------------------------------------------------------
# plot(Xt)
# plot!(Xe, label="Hard Constraint", yticks = 345:2:355, xlabel="C_A", ylabel="T", title="episode 61")
# plot!(Xr, label="CIS")
# # Compare reward functions
# plot!([2.08890038e-01, 2.60350929e-01],[3.53971199e+02, 3.45112189e+02], label="RL1, action=285.87", arrow=true, markershape=[:circle])
# plot!([2.08890038e-01, 2.60350929e-01],[3.53971199e+02, 3.46598395e+02], label="RL2, action=292.98", arrow=true, markershape=[:diamond])
# plot!([2.08890038e-01, 2.60350929e-01],[3.53971199e+02, 3.49655305e+02], label="RL3, action=307.59", arrow=true, markershape=[:star5])
# plot!([2.08890038e-01, 2.60350929e-01, 3.06858649e-01, 0.34794325,  0.3836238, 0.41386536, 0.43852779],
# [3.53971199e+02, 3.50747100e+02, 3.48836765e+02, 347.7169626, 347.21772034, 347.25021766, 347.76389842], label="RL3_run2, action=312.81, 314.52, 70, 71, 65, 44", arrow=true, markershape=[:utri])

# # RL 1 ------------------------------------------------------------
# # Import bad state transitions
# bad_states_idx = np.load("data/reward4_run4_rl2/bad_states_idx.npy", allow_pickle=true)
# bad_states_detected = np.load("data/reward4_run4_rl2/bad_states_detected.npy", allow_pickle=true)
# # plot
# figure1 = plot(Xt, yticks = 345:2:355, xlabel="C_A", ylabel="T", title="RL3-Run9")
# plot!(Xe, label="Hard Constraint")
# plot!(Xr, label="CIS")
# for (idx, item) in enumerate(bad_states_idx)
#     plot!([bad_states_detected[idx].T[1,:]],[bad_states_detected[idx].T[2,:]], label="epi "*string(item), arrow=true, markshape=[:auto])
# end
# figure1

# RL 2 ------------------------------------------------------------
# Import bad state transitions
# data_dir = "data/reward4_run4_rl2/"
# bad_states_idx = np.load(data_dir*"bad_states_idx.npy", allow_pickle=true)
# bad_states_detected = np.load(data_dir*"bad_states_detected.npy", allow_pickle=true)
# # plot
# figure2 = plot(Xt, yticks = 345:2:355, xlabel="C_A", ylabel="T", title="RL2-Run4", size=(750,450))
# plot!(Xe, label="Hard Constraint")
# plot!(Xr, label="CIS_use_eco_zone")
# plot!(iset_cvxh, label="CIS_use_pkg")
# for (idx, item) in enumerate(bad_states_idx)
#     plot!([bad_states_detected[idx].T[1,:]],[bad_states_detected[idx].T[2,:]], label="epi "*string(item), legend=:outertopright, arrow=true, markershape=[:auto])
# end
# figure2
# savefig(figure2, data_dir*"bad_states_transitions_plot.png")

# # RL 3 - Run9 ------------------------------------------------------------
# # Import bad state transitions
# data_dir = "data/reward1_run9_rl3/"
# bad_states_idx = np.load(data_dir*"bad_states_idx.npy", allow_pickle=true)
# bad_states_detected = np.load(data_dir*"bad_states_detected.npy", allow_pickle=true)
# # plot
# figure3 = plot(Xt, yticks = 345:2:355, xlabel="C_A", ylabel="T", title="RL3-Run9", size=(750,450))
# plot!(Xe, label="Hard Constraint")
# plot!(Xr, label="CIS_use_eco_zone")
# plot!(iset_cvxh, label="CIS_use_pkg")
# for (idx, item) in enumerate(bad_states_idx)
#     plot!([bad_states_detected[idx].T[1,:]],[bad_states_detected[idx].T[2,:]], label="epi "*string(item), legend=:outertopright, arrow=true, markershape=[:auto])
# end
# figure3
# savefig(figure3, data_dir*"bad_states_transitions_plot.png")


# # RL 3 - Run 2 ------------------------------------------------------------
# # Import bad state transitions
# data_dir = "data/reward1_run2_rl3/"
# bad_states_idx = np.load(data_dir*"bad_states_idx.npy", allow_pickle=true)
# bad_states_detected = np.load(data_dir*"bad_states_detected.npy", allow_pickle=true)
# # plot
# figure4 = plot(Xt, yticks = 345:2:355, xlabel="C_A", ylabel="T", title="RL3-Run2", size=(750,450))
# plot!(Xe, label="Hard Constraint")
# plot!(Xr, label="CIS_use_eco_zone")
# plot!(iset_cvxh, label="CIS_use_pkg")
# for (idx, item) in enumerate(bad_states_idx)
#     plot!([bad_states_detected[idx].T[1,:]],[bad_states_detected[idx].T[2,:]], label="epi "*string(item), legend=:outertopright, arrow=true, markershape=[:auto])
# end
# figure4
# savefig(figure4, data_dir*"bad_states_transitions_plot.png")

# println("done")